#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <memory.h>
#include <pthread.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <numaif.h>
#include <sys/shm.h>

#include "hash_table.h"
#include "dbg.h"

#define M_HUGE_PAGE_SIZE        2097152
#define N_HUGE_PAGE_RETRY_STEP  64
#define HUGEPAGE_PROC_FILE  "/proc/sys/vm/nr_hugepages"

/* Virtual/physical address pair for one mapped page.
 * NOTE(review): not referenced anywhere in this chunk -- presumably used
 * by the skipped first-level page allocator; confirm before removing. */
typedef struct
{
        void *vaddr;    /* process virtual address of the page */
        void *paddr;    /* physical address of the page */
}mpage_proto_t;

/* Process-wide allocator state; zeroed and set up by the _init() constructor. */
typedef struct
{
        hashtable_t             shm_ht;            /* table of global_mem_entry_t, one per allocation */
        pthread_mutex_t         mutex;             /* guards shm_ht and the counters below */
        uint64_t                total_size;        /* bytes currently allocated, huge-page rounded */
        uint64_t                nr_hugepages;      /* pages believed present in the system pool
                                                      (NOTE: never written in this chunk -- verify) */
        uint64_t                nr_reserved_maybe; /* extra slack pages requested after failed allocs */
}global_mem_t;

/* Bookkeeping record for one region handed out by scale_mem_zmalloc(). */
typedef struct
{
        int shm_id;     /* SysV shared-memory id backing the region */
        void *vaddr;    /* mapped virtual address returned to the caller */
        uint64_t size;  /* requested size in bytes */
}global_mem_entry_t;

/* Singleton allocator state; initialized by the _init() constructor. */
static global_mem_t g_global_mem;

/* Optional hooks fired after a region is allocated / when it is freed
 * (registered via scale_mem_set_callback()). */
int (*g_scale_mem_alloc_callback)(void *vaddr, size_t len) = NULL;
int (*g_scale_mem_free_callback)(void *vaddr, size_t len) = NULL;

/* First level abstraction: alloc/free single huge pages into a ring buffer;
 * can be replaced with shm. */
//skipped............

/* Hash function for the shm table: the key is read as a 32-bit integer
 * and used directly as the hash value. */
static uint32_t mem_hash_key_func(const void *k)
{
        const uint32_t *key = k;

        return *key;
}

/* Ordering comparison for shm table entries.  s1 is a stored
 * global_mem_entry_t; s2 is the lookup key, which is the raw vaddr
 * pointer value itself (NOT a pointer to it).  Returns -1/0/1. */
static int mem_hash_cmp_func(const void *s1, const void *s2)
{
        const global_mem_entry_t *entry = s1;

        if (entry->vaddr > s2)
                return 1;
        if (entry->vaddr < s2)
                return -1;

        return 0;
}

/* Register the optional allocation/free hooks invoked by the allocator
 * (e.g. so users can register/deregister memory with other subsystems).
 * Either pointer may be NULL to disable that hook. */
void scale_mem_set_callback(int (*mem_alloc_callback)(void *vaddr, size_t len),  int (*mem_free_callback)(void *vaddr, size_t len))
{
        g_scale_mem_free_callback = mem_free_callback;
        g_scale_mem_alloc_callback = mem_alloc_callback;
}

/* Runs automatically before main() (GCC constructor attribute): zeroes the
 * global allocator state, creates the allocation table and its mutex. */
static void __attribute__((constructor)) _init(void)
{
        int rc;

        memset(&g_global_mem, 0, sizeof(g_global_mem));

        g_global_mem.shm_ht = hash_create_table(mem_hash_cmp_func,
                                                mem_hash_key_func,
                                                "global memory");

        rc = pthread_mutex_init(&g_global_mem.mutex, NULL);
        assert(rc == 0);
        (void)rc;
}

/*
 * Grow the system-wide huge page pool to at least total_hugepage_num pages
 * by writing to /proc/sys/vm/nr_hugepages.
 *
 * Returns 0 on success, a positive errno value on failure.
 *
 * Fixes: snprintf was bounded to 8 bytes although tmp is 16, silently
 * truncating large page counts; a short write to the proc file was
 * treated as success.
 */
static int adjust_nr_hugepages(int total_hugepage_num)
{
        int ret, fd;
        char tmp[16];
        int count = 0;  /* previously reserved pages; currently not re-read */
        size_t len;
        ssize_t nw;

        fd = open(HUGEPAGE_PROC_FILE, O_RDWR);
        if (fd < 0) {
                ret = errno;
                DERROR("open hugepage proc file [%s]failed :%s\n", HUGEPAGE_PROC_FILE, strerror(ret));
                GOTO(err_ret, ret);
        }

        DINFO("alloc hugepages %d, previous is: %d\n", total_hugepage_num, count);

        /* +32 keeps a small slack of pages beyond the strict requirement */
        snprintf(tmp, sizeof(tmp), "%d", total_hugepage_num + count + 32);
        len = strlen(tmp);

        /* the proc file takes the whole value in one write; a short or
         * failed write means the kernel rejected the request */
        nw = write(fd, tmp, len);
        if (nw < 0 || (size_t)nw != len) {
                ret = (nw < 0) ? errno : EIO;
                DERROR("write hugepage proc file [%s]failed :%s\n", HUGEPAGE_PROC_FILE, strerror(ret));
                GOTO(err_fd, ret);
        }

        close(fd);

        return 0;

err_fd:
        close(fd);
err_ret:
        return ret;

}

/* Second level abstraction: alloc/free N huge pages with a contiguous
 * virtual address range. */
/*
 * Allocate sz bytes of huge-page backed SysV shared memory.
 *
 * The segment is created with SHM_HUGETLB, attached, and then immediately
 * marked IPC_RMID so the kernel reclaims the huge pages when the last
 * process detaches -- even if the process dies uncleanly.
 *
 * On success returns the mapped address and stores the segment id in
 * *shmid; returns NULL on failure.
 *
 * Fix: sz is a size_t -- printing it with %d was undefined behavior;
 * use %zu.
 */
void *mem_huge_alloc_base(size_t sz, int *shmid)
{
        int shmemid;
        uint8_t *buf;

        /* allocate memory */
        shmemid = shmget(IPC_PRIVATE, sz, SHM_HUGETLB | IPC_CREAT | SHM_R | SHM_W);

        if (shmemid < 0)
        {
                DERROR("shmget rdma pool sz:%zu failed\n", sz);

                return NULL;
        }

        /* get pointer to allocated memory */
        buf = shmat(shmemid, NULL, 0);

        if (buf == (void *)-1)
        {
                DERROR("Shared memory attach failure (errno=%d %m)", errno);
                shmctl(shmemid, IPC_RMID, NULL);

                return NULL;
        }

        /* mark 'to be destroyed' when process detaches from shmem segment;
           this will clear the HugePage resources even if the process is
           killed uncleanly.  From the shmctl man page it is unlikely that
           it will fail here. */
        if (shmctl(shmemid, IPC_RMID, NULL))
        {
                DERROR("Shared memory control mark 'to be destroyed' failed (errno=%d %m)", errno);
        }

        DBUG("Allocated huge page sz:%zu\n", sz);
        *shmid = shmemid;

        return buf;
}

/*
 * Detach a huge-page segment obtained from mem_huge_alloc_base().  The
 * segment was marked IPC_RMID at allocation time, so detaching the last
 * mapping releases the huge pages back to the kernel.
 *
 * Invokes the registered free callback (if any) with the segment size
 * recovered via IPC_STAT.
 *
 * Fix: failure of IPC_STAT was checked only by assert(); under NDEBUG
 * that left ds uninitialized and ds.shm_segsz was read anyway (UB).
 * Also use %zu for the size_t segment size.
 */
void mem_huge_free_base(uint8_t *pool_buf, int shmid)
{
        struct shmid_ds ds;
        size_t seg_size = 0;
        int ret;

        ret = shmctl(shmid, IPC_STAT, &ds);
        if (ret != 0) {
                DERROR("shmem IPC_STAT failure (errno=%d %m)", errno);
                assert(0);
        } else {
                seg_size = ds.shm_segsz;
        }

        DINFO("mem free pool, size: %zu\r\n", seg_size);

        if (shmdt(pool_buf) != 0) {
                DERROR("shmem detach failure (errno=%d %m)", errno);
        }

        /* NOTE(review): verify callers do not also invoke this callback
         * themselves, or it will fire twice per free. */
        if (g_scale_mem_free_callback)
                g_scale_mem_free_callback(pool_buf, seg_size);
}

/*
 * Allocate sz bytes (must be a multiple of the 2 MiB huge page size) of
 * zeroed, huge-page backed memory, preferentially bound to NUMA node
 * socketid.  Grows the system huge page pool as needed and retries until
 * the allocation succeeds, enlarging the reserve by N_HUGE_PAGE_RETRY_STEP
 * pages after each failure.
 *
 * Returns the mapped address; aborts on bookkeeping failure.
 *
 * Fixes: nr_reserved_maybe was incremented without the mutex (data race
 * with concurrent allocators that read it under the lock); malloc() result
 * was unchecked; the page loop index was a signed int compared against a
 * size_t quotient.
 */
void *scale_mem_zmalloc(size_t sz, int socketid)
{
        int ret;
        int shmid;
        void *ptr;
        uint64_t total_hugepages;
        size_t i, nr_pages;

        /* fail fast: the per-page loop below assumes whole pages */
        assert(sz % M_HUGE_PAGE_SIZE == 0);

again:
        pthread_mutex_lock(&g_global_mem.mutex);

        total_hugepages = round_up(sz, M_HUGE_PAGE_SIZE) + g_global_mem.total_size;
        total_hugepages /= M_HUGE_PAGE_SIZE;

        if (g_global_mem.nr_hugepages < total_hugepages) {
                adjust_nr_hugepages(total_hugepages + g_global_mem.nr_reserved_maybe);
        }

        pthread_mutex_unlock(&g_global_mem.mutex);

        ptr = mem_huge_alloc_base(sz, &shmid);
        if (!ptr) {
                /* take the lock: this counter is read under the mutex above */
                pthread_mutex_lock(&g_global_mem.mutex);
                g_global_mem.nr_reserved_maybe += N_HUGE_PAGE_RETRY_STEP;
                pthread_mutex_unlock(&g_global_mem.mutex);
                goto again;
        }

        DINFO("reserved huge page maybe: %lu\r\n", g_global_mem.nr_reserved_maybe);

        nr_pages = sz / M_HUGE_PAGE_SIZE;
        for (i = 0; i < nr_pages; i++) {
                void *page = (uint8_t *)ptr + i * M_HUGE_PAGE_SIZE;
                unsigned long nodemask = 1UL << socketid;

                /* MPOL_PREFERRED: fall back to another node instead of
                 * failing when socketid is out of memory.
                 * NOTE(review): maxnode=3 only covers nodes 0-1; widen for
                 * machines with more NUMA nodes. */
                ret = mbind(page, M_HUGE_PAGE_SIZE, MPOL_PREFERRED, &nodemask, 3, 0);
                if (ret)
                        DWARN("mbind failed with error, ret=%d, errno=%d\n", ret, errno);

                /* touching the page both zeroes it and faults it in on the
                 * preferred node */
                memset(page, 0, M_HUGE_PAGE_SIZE);
        }

        if (g_scale_mem_alloc_callback)
                g_scale_mem_alloc_callback(ptr, sz);

        pthread_mutex_lock(&g_global_mem.mutex);

        global_mem_entry_t *entry = malloc(sizeof(*entry));
        if (!entry) {
                /* losing track of a live segment is unrecoverable */
                assert(0);
        }
        entry->vaddr = ptr;
        entry->size = sz;
        entry->shm_id = shmid;

        ret = hash_table_insert(g_global_mem.shm_ht, (void *)entry, &entry->shm_id, 0);
        if (unlikely(ret)) {
                assert(0);
        }

        g_global_mem.total_size += round_up(sz, M_HUGE_PAGE_SIZE);
        pthread_mutex_unlock(&g_global_mem.mutex);

        return ptr;
}

/*
 * Release a region obtained from scale_mem_zmalloc().
 *
 * Fix: the free callback used to fire twice per free -- once here and
 * once inside mem_huge_free_base().  The duplicate invocation here is
 * removed; mem_huge_free_base() is the single invocation point.
 */
void scale_mem_zfree(void *ptr)
{
        int ret;
        global_mem_entry_t *entry = NULL;

        pthread_mutex_lock(&g_global_mem.mutex);

        /* NOTE(review): the table is populated with &entry->shm_id as the
         * key (scale_mem_zmalloc) but searched here with &ptr; confirm
         * hash_table_remove's key semantics against mem_hash_cmp_func,
         * which compares the stored vaddr with the raw key pointer. */
        ret = hash_table_remove(g_global_mem.shm_ht, (void *)&ptr, (void **)&entry);
        assert(!ret);
        assert(entry);

        g_global_mem.total_size -= round_up(entry->size, M_HUGE_PAGE_SIZE);

        pthread_mutex_unlock(&g_global_mem.mutex);

        /* mem_huge_free_base() invokes g_scale_mem_free_callback itself */
        mem_huge_free_base(entry->vaddr, entry->shm_id);
        free(entry);
}